import itertools
from collections import defaultdict
from typing import Dict, Union, Set, List

from ..utils import func


def _roommates_check(crawl_data: dict) -> Dict[str, list]:
    """Group the UDP ports of all reachable nodes by the IP address they share.

    Args:
        crawl_data: crawl result whose ``"up"`` entry is an iterable of node
            records, each providing ``"IP address"`` and ``"UDP port"``.

    Returns:
        A ``defaultdict(list)`` mapping every IP address to the UDP ports seen
        on it, in crawl order. Ports are not de-duplicated.
    """
    ports_by_ip = defaultdict(list)
    for entry in crawl_data["up"]:
        ip = entry["IP address"]
        ports_by_ip[ip].append(entry["UDP port"])
    return ports_by_ip


def _ip_aliases_check(crawl_data: dict) -> Dict[str, Dict[str, Set[Union[str, int]]]]:
    """Collect, per IP address, every node ID and every UDP port seen on it.

    Args:
        crawl_data: crawl result whose ``"up"`` entry is an iterable of node
            records, each providing ``"IP address"``, ``"UDP port"`` and an
            iterable ``"Seen node IDs"``.

    Returns:
        A nested ``defaultdict``: IP address -> ``{"addresses": set of node
        IDs, "ports": set of UDP ports}``.
    """
    aliases = defaultdict(lambda: defaultdict(set))
    for entry in crawl_data["up"]:
        record = aliases[entry["IP address"]]
        # "addresses" holds the node IDs observed at this IP, not IP addresses.
        record["addresses"].update(entry['Seen node IDs'])
        record["ports"].add(entry["UDP port"])
    return aliases


def _subnet_check(crawl_data: dict) -> Dict[str, Set[str]]:
    """Group reachable nodes by the first three octets of their IP address.

    Args:
        crawl_data: crawl result whose ``"up"`` entry is an iterable of node
            records, each providing a dotted ``"IP address"`` and a
            ``"UDP port"``.

    Returns:
        A ``defaultdict(set)`` mapping each three-octet prefix (for example
        ``"10.0.0"``) to the set of ``(IP address, UDP port)`` tuples found
        under it.
    """
    members = defaultdict(set)
    for entry in crawl_data["up"]:
        ip = entry["IP address"]
        # Keep only the leading three dot-separated parts as the subnet key.
        subnet = ".".join(ip.split(".")[:3])
        members[subnet].add((ip, entry["UDP port"]))
    return members


def _concentration_check(crawl_data: dict, threshold: int, group_size_threshold: int = 1) -> List[
    Dict[str, Union[int, dict]]]:
    """Report runs of neighbouring node IDs whose pairwise prefix reaches ``threshold``.

    Every ``(node ID, IP address, UDP port)`` sighting in ``crawl_data["up"]``
    is sorted by node ID and then by the numeric value of the IP octets. Each
    pair of neighbours in that order is handed to ``func``, which must return
    ``(prefix, a, b)``.

    NOTE(review): ``func`` comes from ``..utils`` and is not visible here;
    presumably ``prefix`` is the length of the ID prefix shared by ``a`` and
    ``b`` -- confirm against its definition.

    A pair whose prefix is at least ``threshold`` extends the current run,
    except that a prefix of exactly 512 is skipped without ending the run. A
    pair below ``threshold`` closes the run.

    Args:
        crawl_data: crawl result whose ``"up"`` entry is an iterable of node
            records providing ``"IP address"``, ``"UDP port"`` and an iterable
            ``"Seen node IDs"``.
        threshold: minimum prefix for a pair to join a run. When ``None``, the
            smallest ``t`` with ``2 ** t >= number of sightings`` is used.
        group_size_threshold: a run is reported only if it contains at least
            ``group_size_threshold + 1`` recorded pairs.

    Returns:
        One dict per reported run with ``"count"`` (recorded pairs plus one),
        ``"max_shared_prefix"`` (the smallest prefix among the run's pairs)
        and ``"ids"`` (the list of recorded pairs).
    """

    def _order(sighting):
        # Compare IPs numerically so that e.g. "9.x" sorts before "10.x".
        return sighting[0], tuple(int(octet) for octet in sighting[1].split('.'))

    sightings = [
        (node_id, entry["IP address"], entry["UDP port"])
        for entry in crawl_data["up"]
        for node_id in entry['Seen node IDs']
    ]
    sightings.sort(key=_order)

    if threshold is None:
        # Default: ceil(log2(number of sightings)), computed by doubling.
        threshold = 0
        total = len(sightings)
        while 2 ** threshold < total:
            threshold += 1

    # 512 is both the "no prefix recorded yet" start value and the prefix
    # that is never recorded.
    # NOTE(review): presumably the value func yields for identical IDs -- confirm.
    skip_prefix = 512

    reports = []
    run = []
    run_min = skip_prefix

    def _report(pairs, lowest):
        """Append a finished run to the output if it is long enough."""
        if len(pairs) >= group_size_threshold + 1:
            reports.append({"count": len(pairs) + 1, "max_shared_prefix": lowest, "ids": pairs})

    for neighbours in zip(sightings, sightings[1:]):
        prefix, left, right = func(neighbours)
        if prefix >= threshold:
            if prefix != skip_prefix:
                run.append({"a": {"id": left[0], "ip": left[1], "udp": left[2]},
                            "b": {"id": right[0], "ip": right[1], "udp": right[2]},
                            "prefix": prefix})
                run_min = min(run_min, prefix)
        elif run:
            # Pair fell below the threshold: close the run and start afresh.
            _report(run, run_min)
            run = []
            run_min = skip_prefix

    # A run still open when the input is exhausted is reported as well.
    if run:
        _report(run, run_min)
    return reports


def _identical_check(crawl_data: dict) -> Dict[str, list]:
    """List, for every node ID, the endpoints at which that ID was seen.

    Sightings are ordered by node ID and then by the numeric value of the IP
    octets, so both the mapping's keys and each endpoint list come out in that
    order. IDs seen at only one endpoint are included too; this function does
    no filtering.

    Args:
        crawl_data: crawl result whose ``"up"`` entry is an iterable of node
            records providing ``"IP address"``, ``"UDP port"`` and an iterable
            ``"Seen node IDs"``.

    Returns:
        A ``defaultdict(list)`` mapping each node ID to a list of
        ``{"ip": ..., "udp": ...}`` dicts.

    Raises:
        ValueError: if an IP address has a dot-separated part that is not an
            integer (raised by the numeric sort key).
    """
    sightings = sorted(
        (
            (node_id, node["IP address"], node["UDP port"])
            for node in crawl_data["up"]
            for node_id in node['Seen node IDs']
        ),
        # Numeric octet comparison: "9.0.0.1" must sort before "10.0.0.2".
        key=lambda s: (s[0], tuple(map(int, s[1].split('.')))),
    )

    endpoints_by_id = defaultdict(list)
    for node_id, ip, udp in sightings:
        # Build the endpoint record explicitly rather than slicing the keys
        # of a larger dict, which depends on dict insertion order.
        endpoints_by_id[node_id].append({"ip": ip, "udp": udp})
    return endpoints_by_id
